"""
The code is released exclusively for review purposes with the following terms:
PROPRIETARY AND CONFIDENTIAL. UNAUTHORIZED USE, COPYING, OR DISTRIBUTION OF THE 
CODE, VIA ANY MEDIUM, IS STRICTLY PROHIBITED. BY ACCESSING THE CODE, THE 
REVIEWERS AGREE TO DELETE THEM FROM ALL MEDIA AFTER THE REVIEW PERIOD IS OVER.
"""
""" Create LINEX explanations for the various examples. """
import os
os.environ["OMP_NUM_THREADS"] = "1"

import numpy as np
import sys
sys.path.append("../utilities/")
from time import time

from joblib import Parallel, delayed
from sklearn.utils import check_random_state
import yaml
import pickle
from threadpoolctl import threadpool_limits, threadpool_info

# fname_lime_exp
from utils import (fname_env_perts, fname_exp,
                    fname_base_perts, fname_preds, compute_weights,
                    create_dir_if_not_exist)
from helpers import lrg_lsq_sparse, lrg_lsq, lrg_lsq_sparse_simple

# Pass arguments and run the code.
# Options arrive as strings (or None when omitted); numeric ones are
# converted further down, except --maple_leaves which argparse converts.
import argparse
parser = argparse.ArgumentParser(
    description="Create LINEX explanations for the various examples.")
parser.add_argument("--config_fname",
                    help="YAML config file name, looked up under ./config")
parser.add_argument("--dataset_key")
parser.add_argument("--model_key")
parser.add_argument("--pert_key",
                    help='perturbation key; "MAPLE" switches to random-forest '
                         '(MAPLE) sample weights')
parser.add_argument("--start_ex", help="first example index (inclusive)")
parser.add_argument("--end_ex", help="end example index (exclusive)")
parser.add_argument("--num_cpus", help="number of joblib workers (default 1)")
parser.add_argument("--maple_leaves", type=int,
                    help="min_samples_leaf of the MAPLE random forest; "
                         "required when --pert_key is MAPLE")
args = parser.parse_args()
# --maple_leaves is only consumed by the MAPLE weighting, so it is optional
# for every other perturbation key instead of crashing on int(None).
if args.pert_key == "MAPLE" and args.maple_leaves is None:
    parser.error("--maple_leaves is required when --pert_key is MAPLE")

def _load_pickle(path):
    """Return the object unpickled from *path*, closing the file afterwards.

    NOTE(review): pickle.load can execute arbitrary code -- only load files
    produced by this pipeline.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Load the config file.
# NOTE(review): yaml.FullLoader is not intended for untrusted input; if the
# config is plain data, yaml.safe_load would be the safer choice.
with open(os.path.join("config", args.config_fname)) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Load the data: environment and base perturbations ...
dirname = os.path.join("data", args.dataset_key, "perturbations")
env_pert_fname = fname_env_perts(config, "Env_Perturbations",
                                 args.pert_key, args.dataset_key) + ".pkl"
env_perturbations = _load_pickle(os.path.join(dirname, env_pert_fname))

base_pert_fname = fname_base_perts(config,
                                   args.pert_key, args.dataset_key) + ".pkl"
base_perturbations = _load_pickle(os.path.join(dirname, base_pert_fname))

# ... the model's predictions on the perturbations ...
dirname = os.path.join("data", args.dataset_key, "predictions")
preds_fname = fname_preds(config, args.pert_key, args.model_key,
                          args.dataset_key) + ".pkl"
y_pred_perts = _load_pickle(os.path.join(dirname, preds_fname))

# ... and the per-environment LIME explanations.
dirname = os.path.join("data", args.dataset_key, "explanations")
lime_exp_fname = fname_exp(config, "LIME", "Env_Perturbations",
                           args.pert_key, args.model_key,
                           args.dataset_key) + ".pkl"
lime_exps = _load_pickle(os.path.join(dirname, lime_exp_fname))
lime_exps_envs = lime_exps["lime_envs"]

# Quantities shared by all examples (base as well as individual envs).
n_data_all = len(base_perturbations["indices"])  # total number of examples
n_perts_env = config["Env_Perturbations"]["cnt"]
num_envs = config["Env_Perturbations"]["num_envs"]
p = base_perturbations["samp_perts_exp"][0].shape[1]  # number of features
samp_perts_exp = base_perturbations["samp_perts_exp"]
kernel_width = config["Env_Perturbations"]["kernel_width"]
# Per-example row indices of each environment, read as samp_inds_env[idx, :, env].
samp_inds_env = env_perturbations["samp_inds_env"]
normalize_weights = config["Env_Perturbations"]["normalize_weights"]
num_nonzeros = config["LIME"]["non_zeros"]
# Largest absolute value anywhere in the LIME env explanations.
max_value_coef = np.max(np.abs(lime_exps_envs))
# Non-zero inds chosen from the env explanations: True where the LIME
# coefficients summed (in absolute value) over axis 2 are nonzero.
# NOTE(review): assumes lime_exps_envs is (n_examples, p, num_envs) so that
# envs_nzinds[idx] is a length-p feature mask -- confirm.
envs_nzinds = np.abs(lime_exps_envs).sum(axis=2) > 0.0

# Resolve the half-open range [start_ex, end_ex) of examples this run explains.
# A missing bound defaults to 0 / n_data_all.  The output file name carries a
# "_<start>_<end>" suffix only when at least one bound was given explicitly.
explicit_range = not (args.start_ex is None and args.end_ex is None)
args.start_ex = 0 if args.start_ex is None else int(args.start_ex)
args.end_ex = n_data_all if args.end_ex is None else int(args.end_ex)
fname_suffix = ("_{}_{}".format(args.start_ex, args.end_ex)
                if explicit_range else "")

# Number of joblib workers; a single worker unless --num_cpus is given.
NUM_CPUS = 1 if args.num_cpus is None else int(args.num_cpus)

# When LIME's sparsity budget covers every feature, the LRG fit uses a ridge
# penalty; otherwise it uses the L1 (sparse) solver (see the "ridge"/"l1"
# entries of lrg_config).
set_ridge = num_nonzeros >= p

# Optionally restrict LINEX to the features LIME used (envs_nzinds).  A
# missing or empty LINEX section in the config means "no restriction" rather
# than a KeyError / TypeError.
linex_section = config.get("LINEX") or {}
restrict_coef = bool(linex_section.get("restrict_coef", False))

print(max_value_coef)
# LRG options
lrg_config = {
    "num_iters": 100,
    # Derived from the largest |LIME env coefficient| plus a small slack;
    # presumably a box constraint on the coefficients inside the lrg helpers
    # (TODO confirm in helpers.py).
    "bound": max_value_coef + 1e-6,
    "max_iter": 1,
    "rescale_by_weights": True,
    "fit_intercept": False,
    "ridge": set_ridge,
    "ridge_penalty_multiplier": 1e-1,
    "l1": not set_ridge,
    # NOTE(review): hard-coded to 5 although config["LIME"]["non_zeros"] is
    # loaded as num_nonzeros (and only used for set_ridge) -- confirm that
    # this is intentional.
    "num_nonzeros": 5,
    "randomize_iterations": True,
    "debias": True,
}

# LRG options for the debiasing refit on the features kept by the L1 fit:
# more iterations, ridge only, no L1.
lrg_config_debias = {**lrg_config, "num_iters": 500, "ridge": True, "l1": False}

# Global indices of the examples explained by this run.  (The result matrix
# is built from the parallel results, so no preallocation is needed.)
examples_used = np.arange(args.start_ex, args.end_ex)

# MAPLE: sample weights come from leaf co-occurrence in a random forest.
if args.pert_key == "MAPLE":
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.preprocessing import OneHotEncoder

    # Forest training data: row 0 of every example's perturbation set and the
    # matching prediction.  (presumably row 0 is the unperturbed example --
    # TODO confirm against the perturbation generator)
    X_test = np.vstack([x[0] for x in base_perturbations["samp_perts_exp"]])
    y_test = np.array([y[0] for y in y_pred_perts])

    def train_maple_weights(X_test, y_test):
        """Fit the random forest whose leaves define the MAPLE similarities.

        NOTE(review): the forest is seeded from the wall clock, so repeated
        runs are not reproducible.
        """
        rfr = RandomForestRegressor(min_samples_leaf=args.maple_leaves,
                                    random_state=int(time()))
        rfr.fit(X_test, y_test)
        return rfr

    # RF Regressor for MAPLE weights
    rfr = train_maple_weights(X_test, y_test)

    def compute_maple_weights(x):
        """Return the pairwise leaf-sharing similarity matrix of the rows of *x*.

        Entry (i, j) is the number of trees in which rows i and j land in the
        same leaf, plus 1e-5; the matrix is then divided by its maximum.

        x : 2-D array of samples with the features the forest was trained on.
        Returns an (n, n) ndarray, n = number of rows of x.
        """
        leaves = rfr.apply(x)  # (n_samples, n_estimators) leaf indices
        # One-hot encode each tree's leaf index so that M @ M.T counts, for
        # every pair of rows, the trees in which they share a leaf.
        M = OneHotEncoder().fit_transform(leaves)
        # ``@`` is matrix multiplication for scipy sparse matrices and sparse
        # arrays alike (``*`` is element-wise for sparse arrays).
        S = (M @ M.T).toarray() + 1e-5
        return S / np.max(S)


def _env_training_data(y_pred, samp_pert_exp0, nzinds, idx):
    """Build the per-environment ``[X, y, weights]`` training triples for *idx*.

    X keeps only the columns selected by *nzinds*; weights are the square
    roots of the local sample weights, normalised along axis 0.
    """
    samp_pert_exp = samp_pert_exp0[:, nzinds]
    data_train_env = []
    for env in range(num_envs):
        # Rows of this example's perturbations that belong to environment env.
        rows = samp_inds_env[idx, :, env]
        samp_pert_exp_env = samp_pert_exp[rows, :]
        y_pred_env = y_pred[rows]

        if args.pert_key == "MAPLE":
            # Row 0 of the similarity matrix: similarity of every sample in
            # the environment to the first one, computed on all features.
            local_weights_env = compute_maple_weights(
                samp_pert_exp0[rows, :])[0]
            if normalize_weights:
                local_weights_env = local_weights_env / local_weights_env.sum()
        else:
            local_weights_env = compute_weights(
                samp_pert_exp_env,
                distance_metric="euclidean",
                kernel_width=kernel_width,
                normalize=normalize_weights)

        # np.sum(..., axis=0) matches the builtin sum() over a numpy array
        # (which reduces along the first axis) but runs vectorised.
        sqrt_weights = np.sqrt(local_weights_env)
        data_train_env.append([samp_pert_exp_env, y_pred_env,
                               sqrt_weights / np.sum(sqrt_weights, axis=0)])
    return data_train_env


def _debias(data_train_env, w_sparse, nzinds):
    """Refit with lrg_config_debias on the features the L1 fit kept.

    Coefficients with |w| < 1e-4 count as zero.  Returns a length-p vector
    that is zero outside the refitted features (all zeros when no feature
    survives the threshold).
    """
    chosen_features = np.abs(w_sparse) >= 1e-4
    w_lrg_lsq = np.zeros(p)
    if np.any(chosen_features):
        data_train_env_debias = [[X[:, chosen_features], y, weights]
                                 for X, y, weights in data_train_env]
        w_lrg_lsq_deb, _, _ = lrg_lsq(data_train_env_debias,
                                      lrg_config_debias)
        # Map the refit coefficients back to full feature positions.
        w_lrg_lsq[np.flatnonzero(nzinds)[chosen_features]] = w_lrg_lsq_deb
    return w_lrg_lsq


def linex_explanation(y_pred, samp_pert_exp0, idx):
    """Compute the LINEX coefficient vector for a single example.

    Parameters
    ----------
    y_pred : predictions for the example's perturbed samples, indexed by the
        same rows as ``samp_pert_exp0``.
    samp_pert_exp0 : 2-D array of the example's perturbed samples (rows) over
        the p features (columns).
    idx : global example index; selects its environment split in
        ``samp_inds_env`` and, when ``restrict_coef`` is set, its LIME mask.

    Returns
    -------
    np.ndarray of shape (p,); excluded features have coefficient 0.
    """
    w_lrg_lsq = np.zeros(p)
    # One BLAS/OpenMP thread per call -- presumably to avoid oversubscription,
    # since parallelism comes from the joblib workers.
    with threadpool_limits(limits={"blas": 1, "openmp": 1}):
        # Drop features that are zero in every perturbed sample and, when
        # restrict_coef is set, features outside envs_nzinds[idx].
        nzinds = ~np.all(samp_pert_exp0 == 0, axis=0)
        if restrict_coef:
            nzinds = nzinds & envs_nzinds[idx]
        if not np.any(nzinds):
            # Nothing left to fit: the explanation is all zeros.
            return w_lrg_lsq

        data_train_env = _env_training_data(y_pred, samp_pert_exp0,
                                            nzinds, idx)

        # Actual LRG
        if lrg_config["l1"]:
            w_envs_current = np.zeros((num_envs, np.sum(nzinds)))
            w_lrg_lsq0, _, _, _ = lrg_lsq_sparse_simple(
                data_train_env, w_envs_current, lrg_config)
        else:
            w_lrg_lsq0, _, _ = lrg_lsq(data_train_env, lrg_config)
        w_lrg_lsq[nzinds] = w_lrg_lsq0

        # Debias if needed: refit without sparsity on the selected features.
        if lrg_config["l1"] and lrg_config["debias"]:
            w_lrg_lsq = _debias(data_train_env, w_lrg_lsq0, nzinds)

    return w_lrg_lsq

# Explain every example in [start_ex, end_ex) in parallel and time the run.
st_time = time()
w_lrg_lsq_list = Parallel(n_jobs=NUM_CPUS)(
    delayed(linex_explanation)(
        y_pred_perts[idx],
        base_perturbations["samp_perts_exp"][idx],
        idx)
    for idx in examples_used)
# np.vstack rejects an empty list, so an empty example range produces a
# (0, p) array instead of a crash.
linex_exps = (np.vstack(w_lrg_lsq_list) if w_lrg_lsq_list
              else np.zeros((0, p)))
linex_time = time() - st_time

# Dump the explanations
exp_fname = fname_exp(config, "LINEX", "Env_Perturbations",
                      args.pert_key, args.model_key,
                      args.dataset_key) + fname_suffix + ".pkl"

linex_explanations = {"linex": linex_exps,
                      "linex_time": linex_time}
dirname = os.path.join("data", args.dataset_key, "explanations")
print(linex_exps.shape)

create_dir_if_not_exist(dirname)

exp_path = os.path.join(dirname, exp_fname)
# ``with`` guarantees the pickle is flushed and the file closed.
with open(exp_path, "wb") as fh:
    pickle.dump(linex_explanations, fh)
print(exp_path)
